package org.zjt.flink.wordcount;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.util.Collector;

import java.util.Comparator;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.SortedMap;
import java.util.TreeMap;
/**
 * Batch word count that reports the {@value #TOP_K} most frequent "words" of a fixed input.
 *
 * <p>Each input line is split into single characters, which are the "words" counted here. Both
 * the counting and the top-K selection are done in two stages: a per-partition pre-aggregation
 * ({@code mapPartition} after {@code rebalance()}), then a single global merge
 * ({@code reduceGroup} on the ungrouped data set).
 *
 * @author juntao.zhang
 * Date: 2018-09-29 5:15 PM
 */
public class WordCountTopN {

    /** Number of most frequent words to report. */
    private static final int TOP_K = 3;

    /**
     * Orders (word, count) pairs from "worst" to "best": lower count first, and among equal
     * counts the lexicographically larger word first. This is a strict total order once words
     * are unique, so the head of a min-heap using it is always the single entry to evict.
     */
    private static final Comparator<Tuple2<String, Integer>> WORST_FIRST =
            new Comparator<Tuple2<String, Integer>>() {
                @Override
                public int compare(Tuple2<String, Integer> a, Tuple2<String, Integer> b) {
                    int byCount = Integer.compare(a.f1, b.f1);
                    return byCount != 0 ? byCount : b.f0.compareTo(a.f0);
                }
            };

    public static void main(String[] args) throws Exception {

        // set up the execution environment
        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        // Configure before anything runs: print() below triggers execution of the plan itself,
        // so setting the parallelism afterwards would have no effect on this job.
        env.setParallelism(4);

        // get input data
        DataSet<String> text = env.fromElements(
                "14159265358979323846264338327950288419716939937510",
                "58209749445923078164062862089986280348253421170679",
                "82148086513282306647093844609550582231725359408128",
                "48111745028410270193852110555964462294895493038196",
                "44288109756659334461284756482337867831652712019091",
                "45648566923460348610454326648213393607260249141273",
                "72458700660631558817488152092096282925409171536436",
                "78925903600113305305488204665213841469519415116094",
                "33057270365759591953092186117381932611793105118548",
                "07446237996274956735188575272489122793818301194912",
                "98336733624406566430860213949463952247371907021798",
                "60943702770539217176293176752384674818467669405132",
                "00056812714526356082778577134275778960917363717872",
                "14684409012249534301465495853710507922796892589235",
                "42019956112129021960864034418159813629774771309960",
                "51870721134999999837297804995105973173281609631859",
                "50244594553469083026425223082533446850352619311881",
                "71010003137838752886587533208381420617177669147303",
                "59825349042875546873115956286388235378759375195778",
                "18577805321712268066130019278766111959092164201989"
        );

        DataSet<Tuple2<String, Integer>> counts = text
                // split up the lines in pairs (2-tuples) containing: (word,1)
                .flatMap(new LineSplitter())
                .rebalance()
                // local word count
                .mapPartition(new MapPartitionFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>() {
                    @Override
                    public void mapPartition(Iterable<Tuple2<String, Integer>> words,
                                             Collector<Tuple2<String, Integer>> out) throws Exception {
                        sumCounts(words, out);
                    }
                })
                // global word count (non-grouped reduceGroup: one task sees all partial counts)
                .reduceGroup(new GroupReduceFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>() {
                    @Override
                    public void reduce(Iterable<Tuple2<String, Integer>> wordcounts,
                                       Collector<Tuple2<String, Integer>> out) throws Exception {
                        sumCounts(wordcounts, out);
                    }
                });

        DataSet<Tuple2<String, Integer>> topK = counts
                .rebalance()
                // local top-K
                .mapPartition(new MapPartitionFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>() {
                    @Override
                    public void mapPartition(Iterable<Tuple2<String, Integer>> wordcounts,
                                             Collector<Tuple2<String, Integer>> out) throws Exception {
                        topK(wordcounts, TOP_K, out);
                    }
                })
                // global top-K over the union of the local winners
                .reduceGroup(new GroupReduceFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>() {
                    @Override
                    public void reduce(Iterable<Tuple2<String, Integer>> topList,
                                       Collector<Tuple2<String, Integer>> out) throws Exception {
                        topK(topList, TOP_K, out);
                    }
                });

        // execute and print result. In the DataSet API print() triggers execution itself; a
        // following env.execute() with no new sinks would throw, so none is issued here.
        topK.print();
    }

    /**
     * Sums the counts per word and emits one (word, total) pair per distinct word.
     *
     * @param wordCounts (word, count) pairs, possibly with repeated words
     * @param out        receives one (word, summed count) pair per distinct word, in ascending
     *                   word order
     */
    private static void sumCounts(Iterable<Tuple2<String, Integer>> wordCounts,
                                  Collector<Tuple2<String, Integer>> out) {
        SortedMap<String, Integer> totals = new TreeMap<String, Integer>();
        for (Tuple2<String, Integer> wc : wordCounts) {
            Integer current = totals.get(wc.f0);
            totals.put(wc.f0, current == null ? wc.f1 : current + wc.f1);
        }
        for (Map.Entry<String, Integer> e : totals.entrySet()) {
            out.collect(Tuple2.of(e.getKey(), e.getValue()));
        }
    }

    /**
     * Emits the {@code k} (word, count) pairs with the highest counts.
     *
     * <p>Ties on count are broken by word: the lexicographically smaller word wins. Words that
     * share a count are therefore no longer silently dropped. Because this is a deterministic
     * total order, taking the top-K per partition and then top-K of the union gives the same
     * result as one global pass.
     *
     * @param wordCounts (word, count) pairs with distinct words
     * @param k          number of pairs to keep; must be non-negative
     * @param out        receives at most {@code k} pairs, in ascending count order
     */
    private static void topK(Iterable<Tuple2<String, Integer>> wordCounts, int k,
                             Collector<Tuple2<String, Integer>> out) {
        // Min-heap whose head is the current worst candidate; capacity k + 1 covers the
        // transient extra element before eviction.
        PriorityQueue<Tuple2<String, Integer>> heap =
                new PriorityQueue<Tuple2<String, Integer>>(k + 1, WORST_FIRST);
        for (Tuple2<String, Integer> wc : wordCounts) {
            // Copy defensively: with object reuse enabled, the runtime may hand out the same
            // instance on every iteration, which must not be retained as-is.
            heap.add(Tuple2.of(wc.f0, wc.f1));
            if (heap.size() > k) {
                heap.poll();
            }
        }
        while (!heap.isEmpty()) {
            out.collect(heap.poll());
        }
    }

    /**
     * Splits a line into single-character tokens and emits (token, 1) for each one.
     */
    public static final class LineSplitter implements FlatMapFunction<String, Tuple2<String, Integer>> {
        @Override
        public void flatMap(String value, Collector<Tuple2<String, Integer>> out) {
            // Splitting on the empty pattern yields one token per character.
            String[] tokens = value.split("");

            for (String token : tokens) {
                // Skip empty tokens; older JDKs (before 8) produce a leading "" for split("").
                if (token.length() > 0) {
                    out.collect(new Tuple2<String, Integer>(token, 1));
                }
            }
        }
    }
}
